# Clear every (non-hidden) object from the global environment before running.
# This removes objects only; attached packages and options() are not reset.
rm(list=ls())
# tidyverse: supplies the dplyr / tidyr / ggplot2 / stringr functions used below
library(tidyverse)
# FactorHet: its internal softmax_matrix() is called when building Figure A4
library(FactorHet)
# mvtnorm: rmvnorm() draws the simulated moderators for Figure A4
library(mvtnorm)
# ggpubr: ggarrange() places several panels in one figure
library(ggpubr)
# Given that the simulations take some time to run, a formatted version of the
# output is provided in RDS files; see README.md for more discussion

############## FIGURES A3 and A4 ###############################################

# Read the stored true simulation parameters and plot histograms of the true
# coefficients (beta) and of the true AMCEs.

set.seed(123)
true_values <- readRDS('code/bs_true_values.RDS')
range(true_values$true_AME$mean)

# True AMCE values, excluding rows with level == 1
true_amce_values <- true_values$true_AME %>%
  filter(level != 1) %>%
  pull(mean)

g_true_AME <- ggplot() +
  geom_histogram(aes(x = true_amce_values)) +
  theme_bw() +
  xlab('AMCE') +
  ylab('Count')

# All true coefficients, row-bound across the elements of true_beta and
# flattened into a single vector
true_beta_values <- as.vector(do.call('rbind', true_values$true_beta))

g_true_beta <- ggplot() +
  geom_histogram(aes(x = true_beta_values)) +
  theme_bw() +
  xlab('Beta') +
  ylab('Count')

g_true_vis <- ggpubr::ggarrange(
  g_true_beta + ggtitle('(a) Beta'),
  g_true_AME + ggtitle('(b) AMCE')
)

# Figure A3
ggsave(
  filename = 'figures/sim_truth.pdf',
  plot = g_true_vis,
  width = 8.5,
  height = 8.5/2
)

# Simulate 10^6 moderator profiles: a constant column plus five multivariate
# normal covariates whose covariance is a Toeplitz matrix in powers of 0.25.
moderator_sigma <- toeplitz(0.25^(0:(5-1)))
sim_moderator <- cbind(
  1,
  mvtnorm::rmvnorm(n = 10^6, sigma = moderator_sigma)
)

# Group-membership probabilities implied by the true phi
sim_pi <- FactorHet:::softmax_matrix(sim_moderator %*% true_values$true_phi)

# One row per (draw, group); `name` holds the digit extracted from the
# column name
sim_pi_long <- pivot_longer(
  data.frame(sim_pi),
  cols = everything(),
  names_pattern = '([0-9])'
)

g_vis_pi <- ggplot(sim_pi_long) +
  geom_density(aes(x = value, group = name)) +
  facet_wrap(~paste0('Group ', name)) +
  xlab('Probability') +
  ylab('Density') +
  theme_bw()

# Figure A4
ggsave(
  filename = 'figures/sim_pi.pdf',
  plot = g_vis_pi,
  width = 8.5,
  height = 8.5/2
)

# Average membership probability per group, and the distribution of each
# draw's largest membership probability
print(colMeans(sim_pi))
max_pi <- apply(sim_pi, MARGIN = 1, max)
quantile(max_pi)

########### FIGURES 2, A5-A12, TABLE A2 ########################################

# Read the formatted simulation results and attach human-readable sample-size
# labels. The raw code is kept in `orig_samplesize`; codes other than
# 'small', 'large', 'very_large' become NA.
cjoint_sims <- readRDS('final_output/final_cjoint.RDS')
ame_sims <- readRDS('final_output/final_AME.RDS')

relabel_sample_size <- function(sims) {
  sims %>%
    mutate(
      orig_samplesize = sample_size,
      sample_size = case_when(
        sample_size == 'small' ~ '1,000 People\n(5 Tasks)',
        sample_size == 'large' ~ '2,000 People\n(10 Tasks)',
        sample_size == 'very_large' ~ '4,000 People\n(10 Tasks)'
      )
    )
}

ame_sims <- relabel_sample_size(ame_sims)
cjoint_sims <- relabel_sample_size(cjoint_sims)

# Per-coefficient performance summary across simulation draws.
# `est_error` is the estimate minus the truth; `post_se` is the square root of
# the reported variance, with negative variances set to zero first.
summary_cjoint <- cjoint_sims %>%
  mutate(
    est_error = coef - truth,
    post_se = sqrt(ifelse(var < 0, 0, var))
  ) %>%
  group_by(type, factor, level, group, sample_size, orig_samplesize, truth) %>%
  summarize(
    # coverage over all draws, and over draws flagged by `selected`
    coverage_all = mean(coverage),
    coverage_ps = mean(coverage[selected]),
    bias = mean(est_error),
    mae = mean(abs(est_error)),
    median_coef = median(coef),
    median_bias = median(est_error),
    avg_coef = mean(coef),
    rmse = sqrt(mean(est_error^2)),
    median_se = median(post_se),
    mean_se = mean(post_se),
    # spread of the point estimates themselves
    sd_of_estimates = sd(coef),
    mean_ad_of_estimates = mean(abs(coef - mean(coef))),
    median_ad_of_estimates = median(abs(coef - median(coef))),
    iqr_of_estimates = IQR(coef)
  ) %>%
  ungroup() %>%
  mutate(
    true_zero = ifelse(
      truth == 0,
      'Zero\nTrue Effect',
      'Non-Zero\nTrue Effect'
    )
  )


# Per-effect performance summary for the marginal-effect estimates.
# `est_error` is the estimate minus the truth; `post_se` is the square root of
# the reported variance, with negative variances set to zero first.
summary_ame <- ame_sims %>%
  mutate(
    est_error = marginal_effect - truth,
    post_se = sqrt(ifelse(var < 0, 0, var))
  ) %>%
  group_by(type, factor, level, group, sample_size, orig_samplesize, truth) %>%
  summarize(
    # coverage over all draws, and over draws flagged by `selected`
    coverage_all = mean(coverage),
    coverage_ps = mean(coverage[selected]),
    median_error = median(abs(est_error)),
    bias = mean(est_error),
    mae = mean(abs(est_error)),
    rmse = sqrt(mean(est_error^2)),
    avg_coef = mean(marginal_effect),
    median_coef = median(marginal_effect),
    median_bias = median(est_error),
    median_se = median(post_se),
    mean_se = mean(post_se),
    # spread of the point estimates themselves
    sd_of_estimates = sd(marginal_effect),
    mean_ad_of_estimates = mean(abs(marginal_effect - mean(marginal_effect))),
    median_ad_of_estimates = median(abs(marginal_effect - median(marginal_effect)))
  ) %>%
  ungroup() %>%
  mutate(
    true_zero = ifelse(
      truth == 0,
      'Zero\nTrue Effect',
      'Non-Zero\nTrue Effect'
    )
  )

# Display label for the two main estimation types; every other `type`
# (e.g. the mis-specified or wrong-K runs) gets NA.
label_estimation_type <- function(type) {
  type_labels <- c(non_adapt = 'Full Data', ssplit = 'Split Sample')
  unname(type_labels[match(type, names(type_labels))])
}

summary_ame <- mutate(summary_ame, fmt_type = label_estimation_type(type))

summary_cjoint <- mutate(summary_cjoint, fmt_type = label_estimation_type(type))

# Figure 2: Simulation results on AMCE
# Both panels use the full-data ('non_adapt') results and leave out the
# 'very_large' design, so the subset is built once and reused.

fig2_data <- summary_ame %>%
  filter(orig_samplesize != 'very_large', type == 'non_adapt')

# Panel (a): average estimate against the truth, with a 45-degree line
g_sim_ame_est <- ggplot(fig2_data) +
  geom_point(
    size = 2,
    aes(x = truth, y = avg_coef, pch = sample_size, col = sample_size)
  ) +
  coord_equal() +
  theme_bw(base_size = 13) +
  theme(legend.position = 'bottom', aspect.ratio = 1) +
  xlab('Truth') +
  ylab('Average Estimate') +
  ggtitle('(a) Estimated Effects') +
  geom_abline(aes(slope = 1, intercept = 0)) +
  scale_color_manual(values = c('black', 'red')) +
  scale_shape_manual(values = c(15, 4)) +
  labs(col = 'Sample Size:', pch = 'Sample Size:')

# Shared limits so the x and y axes of panel (b) cover the same range
range_ame_se_plot <- range(c(fig2_data$mean_se, fig2_data$sd_of_estimates))

# Panel (b): average posterior SD against the SD of the point estimates
g_sim_ame_se <- ggplot(fig2_data) +
  geom_point(
    size = 2,
    aes(x = mean_se, y = sd_of_estimates, pch = sample_size, col = sample_size)
  ) +
  theme_bw(base_size = 13) +
  theme(legend.position = 'bottom', aspect.ratio = 1) +
  xlab('Average Posterior Standard Deviation') +
  ylab('Std. Dev of Estimates') +
  ggtitle('(b) Posterior Standard Deviation') +
  geom_abline(aes(slope = 1, intercept = 0)) +
  xlim(range_ame_se_plot) +
  ylim(range_ame_se_plot) +
  scale_color_manual(values = c('black', 'red')) +
  scale_shape_manual(values = c(15, 4)) +
  labs(col = 'Sample Size:', pch = 'Sample Size:')

# TRUE rather than T: T is an ordinary variable that can be reassigned
g_plot_ame <- ggpubr::ggarrange(
  g_sim_ame_est, g_sim_ame_se,
  common.legend = TRUE, legend = 'bottom'
)
ggsave(
  filename = 'figures/sim_ame.pdf',
  plot = g_plot_ame,
  width = 8.5,
  height = 8.5/2
)

# Figure A5: Simulation results on \hat{\beta}_k
# Both panels use the full-data ('non_adapt') results and leave out the
# 'very_large' design, so the subset is built once and reused.

figA5_data <- summary_cjoint %>%
  filter(orig_samplesize != 'very_large', type == 'non_adapt')

# Panel (a): average estimate against the truth, with a 45-degree line
g_sim_beta_est <- ggplot(figA5_data) +
  geom_point(
    size = 2,
    aes(x = truth, y = avg_coef, pch = sample_size, col = sample_size)
  ) +
  coord_equal() +
  theme_bw(base_size = 13) +
  theme(legend.position = 'bottom', aspect.ratio = 1) +
  xlab('Truth') +
  ylab('Average Estimate') +
  ggtitle('(a) Estimated Effects') +
  geom_abline(aes(slope = 1, intercept = 0)) +
  scale_color_manual(values = c('black', 'red')) +
  scale_shape_manual(values = c(15, 4)) +
  labs(col = 'Sample Size:', pch = 'Sample Size:')

# Shared limits so the x and y axes of panel (b) cover the same range
range_beta_se_plot <- range(c(figA5_data$mean_se, figA5_data$sd_of_estimates))

# Panel (b): average posterior SD against the SD of the point estimates
g_sim_beta_se <- ggplot(figA5_data) +
  geom_point(
    size = 2,
    aes(x = mean_se, y = sd_of_estimates, pch = sample_size, col = sample_size)
  ) +
  theme_bw(base_size = 13) +
  theme(legend.position = 'bottom', aspect.ratio = 1) +
  xlab('Average Posterior Standard Deviation') +
  ylab('Std. Dev of Estimates') +
  ggtitle('(b) Posterior Standard Deviation') +
  geom_abline(aes(slope = 1, intercept = 0)) +
  xlim(range_beta_se_plot) +
  ylim(range_beta_se_plot) +
  scale_color_manual(values = c('black', 'red')) +
  scale_shape_manual(values = c(15, 4)) +
  labs(col = 'Sample Size:', pch = 'Sample Size:')

# TRUE rather than T: T is an ordinary variable that can be reassigned
g_plot_beta <- ggarrange(
  g_sim_beta_est, g_sim_beta_se,
  common.legend = TRUE, legend = 'bottom'
)
ggsave(
  filename = 'figures/sim_beta.pdf',
  plot = g_plot_beta,
  width = 8.5,
  height = 8.5/2
)

# Correlation between the average AMCE estimate and the truth, by type and
# sample size; only the full-data rows are printed (referenced in Section 4)
cor_by_type <- summarize(
  group_by(summary_ame, type, sample_size),
  cor = cor(avg_coef, truth)
)
print(filter(cor_by_type, type == 'non_adapt'))

# Figure A6: distribution of RMSE, bias, and coverage by sample size for the
# full-data and split-sample estimators
# Panel A6.A: AMCE estimates

# Long format: one row per (effect, metric). Rows without a display label
# (fmt_type is NA) are dropped; sample sizes are ordered in reverse sort order.
perf_ame_long <- summary_ame %>%
  filter(!is.na(fmt_type)) %>%
  ungroup() %>%
  mutate(
    sample_size = factor(
      sample_size,
      levels = rev(sort(unique(summary_ame$sample_size)))
    )
  ) %>%
  select(fmt_type, factor, level, group, truth, true_zero, sample_size,
         rmse, bias, coverage_all) %>%
  # id.vars spelled out in full (the short form `id` relies on partial matching)
  reshape2::melt(
    id.vars = c('fmt_type', 'sample_size', 'true_zero', 'truth',
                'factor', 'level', 'group')
  ) %>%
  mutate(
    variable = recode_factor(
      variable,
      'rmse' = 'RMSE', 'bias' = 'Bias', 'coverage_all' = 'Coverage'
    )
  )

# Solid reference lines: 0.95 for coverage, 0 for bias and RMSE
perf_ame_targets <- data.frame(
  x = c(0.95, 0, 0),
  variable = factor(c('Coverage', 'Bias', 'RMSE'))
)
# Dashed reference line at 0.9 in the coverage panel
perf_ame_dashed <- data.frame(x = c(0.9), variable = factor('Coverage'))

app_g_dist_ame <- ggplot(perf_ame_long) +
  geom_boxplot(
    aes(x = value, y = sample_size, col = fmt_type,
        group = interaction(sample_size, fmt_type))
  ) +
  # `scales` spelled out in full rather than the partially matched `scale`
  facet_grid(true_zero ~ variable, scales = 'free_x') +
  geom_vline(aes(xintercept = x), data = perf_ame_targets) +
  geom_vline(aes(xintercept = x), data = perf_ame_dashed, linetype = 'dashed') +
  labs(col = 'Method: ') +
  ylab('Sample Size') +
  xlab('') +
  theme_bw() +
  theme(legend.position = 'bottom', strip.text.y.right = element_text(angle = 0))

ggsave(
  app_g_dist_ame,
  filename = 'figures/app_perf_splitsample_ame.pdf',
  width = 8.5,
  height = 8.5/2
)

# Panel A6.B: the same RMSE / bias / coverage display for the coefficient
# (beta) estimates in summary_cjoint

# Long format: one row per (coefficient, metric). Rows without a display label
# (fmt_type is NA) are dropped. The sample-size level order is taken from the
# labels present in summary_ame.
perf_beta_long <- summary_cjoint %>%
  filter(!is.na(fmt_type)) %>%
  ungroup() %>%
  mutate(
    sample_size = factor(
      sample_size,
      levels = rev(sort(unique(summary_ame$sample_size)))
    )
  ) %>%
  select(fmt_type, factor, level, group, truth, true_zero, sample_size,
         rmse, bias, coverage_all) %>%
  # id.vars spelled out in full (the short form `id` relies on partial matching)
  reshape2::melt(
    id.vars = c('fmt_type', 'sample_size', 'true_zero', 'truth',
                'factor', 'level', 'group')
  ) %>%
  mutate(
    variable = recode_factor(
      variable,
      'rmse' = 'RMSE', 'bias' = 'Bias', 'coverage_all' = 'Coverage'
    )
  )

# Solid reference lines: 0.95 for coverage, 0 for bias and RMSE
perf_beta_targets <- data.frame(
  x = c(0.95, 0, 0),
  variable = factor(c('Coverage', 'Bias', 'RMSE'))
)
# Dashed reference line at 0.9 in the coverage panel
perf_beta_dashed <- data.frame(x = c(0.9), variable = factor('Coverage'))

app_g_dist_beta <- ggplot(perf_beta_long) +
  geom_boxplot(
    aes(x = value, y = sample_size, col = fmt_type,
        group = interaction(sample_size, fmt_type))
  ) +
  # `scales` spelled out in full rather than the partially matched `scale`
  facet_grid(true_zero ~ variable, scales = 'free_x') +
  geom_vline(aes(xintercept = x), data = perf_beta_targets) +
  geom_vline(aes(xintercept = x), data = perf_beta_dashed, linetype = 'dashed') +
  labs(col = 'Method: ') +
  ylab('Sample Size') +
  xlab('') +
  theme_bw() +
  theme(legend.position = 'bottom', strip.text.y.right = element_text(angle = 0))

ggsave(
  app_g_dist_beta,
  filename = 'figures/app_perf_splitsample_beta.pdf',
  width = 8.5,
  height = 8.5/2
)

# For each (true_zero, method, sample size) cell: the average of mean_se, the
# average ratio of mean_se to the SD of the estimates, and the median coverage.
# Rows without a display label (fmt_type is NA) are excluded.
summarize_se_calibration <- function(perf_summary) {
  labelled <- filter(perf_summary, !is.na(fmt_type))
  by_cell <- group_by(labelled, true_zero, fmt_type, sample_size)
  summarize(
    by_cell,
    avg_mean_se = mean(mean_se),
    avg_ratio = mean(mean_se/sd_of_estimates),
    coverage = median(coverage_all)
  )
}

summarize_se_calibration(summary_cjoint)

summarize_se_calibration(summary_ame)


# Figure A11: Performance of simulations when
# moderators are mis-specified

# Legend order for the six method / moderator combinations
misspec_levels <- c(
  'Full Data (Correct Moderators)', 'Full Data (Non-Linear Transf.)',
  'Full Data (No Moderators)', 'Split Sample (Correct Moderators)',
  'Split Sample (Non-Linear Transf.)', 'Split Sample (No Moderators)'
)

# Keep the two correctly specified runs plus every 'misspec' run that ends in
# a full-data or split-sample fit, then reshape to one row per (effect, metric)
perf_misspec_long <- summary_ame %>%
  filter(type %in% c('ssplit', 'non_adapt') |
           grepl(type, pattern='misspec.*(ssplit|non_adapt)')) %>%
  ungroup() %>%
  mutate(
    fmt_type = case_when(
      type == 'non_adapt' ~ 'Full Data (Correct Moderators)',
      type == 'ssplit' ~ 'Split Sample (Correct Moderators)',
      type == 'misspec@nonlinear_mod@non_adapt' ~ 'Full Data (Non-Linear Transf.)',
      type == 'misspec@nonlinear_mod@ssplit' ~ 'Split Sample (Non-Linear Transf.)',
      type == 'misspec@no_mod@non_adapt' ~ 'Full Data (No Moderators)',
      type == 'misspec@no_mod@ssplit' ~ 'Split Sample (No Moderators)',
      TRUE ~ NA_character_
    ),
    fmt_type = factor(fmt_type, levels = misspec_levels),
    sample_size = factor(
      sample_size,
      levels = rev(sort(unique(summary_ame$sample_size)))
    )
  ) %>%
  select(fmt_type, factor, level, group, truth, true_zero, sample_size,
         rmse, bias, coverage_all) %>%
  # id.vars spelled out in full (the short form `id` relies on partial matching)
  reshape2::melt(
    id.vars = c('fmt_type', 'sample_size', 'true_zero', 'truth',
                'factor', 'level', 'group')
  ) %>%
  mutate(
    variable = recode_factor(
      variable,
      'rmse' = 'RMSE', 'bias' = 'Bias', 'coverage_all' = 'Coverage'
    )
  )

# Solid reference lines: 0.95 for coverage, 0 for bias and RMSE
misspec_targets <- data.frame(
  x = c(0.95, 0, 0),
  variable = factor(c('Coverage', 'Bias', 'RMSE'))
)
# Dashed reference line at 0.9 in the coverage panel
misspec_dashed <- data.frame(x = c(0.9), variable = factor('Coverage'))

app_g_plot_nomod_sims <- ggplot(perf_misspec_long) +
  geom_boxplot(
    aes(x = value, y = sample_size, col = fmt_type,
        group = interaction(sample_size, fmt_type))
  ) +
  # `scales` spelled out in full rather than the partially matched `scale`
  facet_grid(true_zero ~ variable, scales = 'free_x') +
  scale_color_manual(values = scales::hue_pal()(6)) +
  geom_vline(aes(xintercept = x), data = misspec_targets) +
  geom_vline(aes(xintercept = x), data = misspec_dashed, linetype = 'dashed') +
  labs(col = 'Method: ') +
  ylab('Sample Size') +
  xlab('') +
  theme_bw() +
  guides(color = guide_legend(nrow = 2, byrow = TRUE)) +
  theme(legend.position = 'bottom', strip.text.y.right = element_text(angle = 0))

ggsave(
  app_g_plot_nomod_sims,
  filename = 'figures/app_perf_nomod_ame.pdf',
  width = 8.5,
  height = 8.5/2
)

# Figure A12: Compare the probability of correct assignment
posterior_sims <- readRDS('final_output/final_posterior.RDS')

# Legend order for the six method / moderator combinations
posterior_type_levels <- c(
  'Full Data (Correct Moderators)', 'Full Data (Non-Linear Transf.)',
  'Full Data (No Moderators)', 'Split Sample (Correct Moderators)',
  'Split Sample (Non-Linear Transf.)', 'Split Sample (No Moderators)'
)

# Stack the `posterior` and `postpred` columns into (variable, value) pairs,
# then attach the display labels used for facets, colours, and axis order
plot_posterior <- posterior_sims %>%
  select(type, sample_size, posterior, postpred) %>%
  # id.vars spelled out in full (the short form `id` relies on partial matching)
  reshape2::melt(id.vars = c('type', 'sample_size')) %>%
  mutate(
    orig_samplesize = sample_size,
    sample_size = case_when(
      sample_size == 'small' ~ '1,000 People\n(5 Tasks)',
      sample_size == 'large' ~ '2,000 People\n(10 Tasks)',
      sample_size == 'very_large' ~ '4,000 People\n(10 Tasks)'
    )
  ) %>%
  ungroup() %>%
  mutate(
    fmt_type = case_when(
      type == 'non_adapt' ~ 'Full Data (Correct Moderators)',
      type == 'misspec@nonlinear_mod@non_adapt' ~ 'Full Data (Non-Linear Transf.)',
      type == 'misspec@no_mod@non_adapt' ~ 'Full Data (No Moderators)',
      type == 'ssplit' ~ 'Split Sample (Correct Moderators)',
      type == 'misspec@nonlinear_mod@ssplit' ~ 'Split Sample (Non-Linear Transf.)',
      type == 'misspec@no_mod@ssplit' ~ 'Split Sample (No Moderators)',
      TRUE ~ NA_character_
    ),
    fmt_type = factor(fmt_type, levels = posterior_type_levels),
    sample_size = factor(
      sample_size,
      levels = rev(sort(unique(summary_ame$sample_size)))
    ),
    fmt_variable = case_when(
      variable == 'posterior' ~ 'Posterior',
      variable == 'postpred' ~ 'Posterior Predictive'
    )
  )

# Types without a display label (fmt_type is NA) are left out of the figure
app_g_plot_posterior_nomoderator <- ggplot(
  plot_posterior %>% filter(!is.na(fmt_type))
) +
  geom_boxplot(
    aes(x = sample_size, y = value, col = fmt_type,
        group = interaction(sample_size, fmt_type))
  ) +
  coord_flip(ylim = c(0,1)) +
  facet_grid(fmt_variable ~ .) +
  theme_bw() +
  guides(color = guide_legend(nrow = 2, byrow = TRUE)) +
  theme(legend.position = 'bottom', strip.text.y.right = element_text(angle = 0)) +
  labs(col = '') +
  xlab('Sample Size') +
  ylab('Average Value in Sampled Group')

ggsave(
  app_g_plot_posterior_nomoderator,
  filename = 'figures/app_perf_nomod_posterior.pdf',
  width = 8.5,
  height = 8.5/2
)

# Figures on heterogeneous effects (Figure A7-A8)

# Read the three HTE result files used in this and the following sections
hte_binscatter <- readRDS('final_output/final_HTE_binscatter.RDS')
hte_sims <- readRDS('final_output/final_HTE.RDS')
hte_AME_sims <- readRDS('final_output/final_HTE_by_AME.RDS')

# Figure A7: Compare binned scatter estimates of HTE
# `group` holds interval labels of the form "(lower,upper]"; the two bounds
# are parsed out and averaged to give the bin midpoint. Within each
# (midpoint, method, sample size) cell the estimate is averaged both
# unweighted (raw_mean) and weighted by n (est).
plot_binscatter <- hte_binscatter %>%
  mutate(
    orig_samplesize = sample_size,
    sample_size = case_when(
      sample_size == 'small' ~ '1,000 People\n(5 Tasks)',
      sample_size == 'large' ~ '2,000 People\n(10 Tasks)',
      sample_size == 'very_large' ~ '4,000 People\n(10 Tasks)'
    ),
    lower = as.numeric(str_extract(group, pattern='(?<=^\\()[-\\.0-9]+')),
    upper = as.numeric(str_extract(group, pattern='[-\\.0-9]+(?=\\]$)')),
    mid = (lower + upper)/2
  ) %>%
  group_by(mid, method, sample_size) %>%
  summarize(raw_mean = mean(est), est = sum(est * n)/sum(n))

# Keep the main fits (K = 3) and the deliberately wrong-K fits, and record the
# number of groups and whether the fit used the full data or a split sample
plot_wrong_K <- plot_binscatter %>%
  filter(grepl(method, pattern='wrong|^non_adapt|^ssplit')) %>%
  mutate(
    K = case_when(
      method %in% c("non_adapt", "ssplit") ~ 3,
      method %in% c('wrong_1_full', 'wrong_1_split') ~ 1,
      method %in% c('wrong_2_full', 'wrong_2_split') ~ 2,
      method %in% c('wrong_4_full', 'wrong_4_split') ~ 4
    ),
    K = factor(paste0('K=',K), levels = c('K=1', 'K=2', 'K=3', 'K=4', 'K=0')),
    split_sample = case_when(
      method %in% c('non_adapt', 'wrong_2_full', 'wrong_1_full', 'wrong_4_full') ~ 'Full Data',
      method %in% c('ssplit', 'wrong_2_split','wrong_1_split', 'wrong_4_split') ~ 'Split Sample'
    )
  )

# Estimate against bin midpoint, with the default 45-degree reference line
g_plot_wrong_K <- ggplot(plot_wrong_K) +
  geom_point(aes(x = mid, y = est, col = factor(K))) +
  facet_grid(split_sample ~ sample_size) +
  coord_equal() +
  geom_abline() +
  xlab('Truth') +
  ylab('Estimate') +
  theme_bw() +
  theme(legend.position = 'bottom', strip.text.y.right = element_text(angle = 0)) +
  labs(col = '')

ggsave(
  g_plot_wrong_K,
  filename = 'figures/app_plot_hte_wrong_k.pdf',
  width = 8.5,
  height = 8.5/2
)

# Figure A8: Quality of HTE estimates across sample size, model, and K

# Attach the display label for sample size; the raw code is kept in
# `orig_samplesize`
hte_sims <- hte_sims %>%
  mutate(
    orig_samplesize = sample_size,
    sample_size = case_when(
      sample_size == 'small' ~ '1,000 People\n(5 Tasks)',
      sample_size == 'large' ~ '2,000 People\n(10 Tasks)',
      sample_size == 'very_large' ~ '4,000 People\n(10 Tasks)'
    )
  )

# One row per (method, simulation, sample size, metric) for the rmse and bias
# columns, restricted to the main fits and the wrong-K fits.
# The previous version also computed an `abs_bias` column that was dropped by
# the select() immediately afterwards; that dead step is removed.
reshape_hte_sims <- hte_sims %>%
  select(method, sim, sample_size, rmse, bias) %>%
  pivot_longer(cols = -c(method, sample_size, sim), names_to = 'variable') %>%
  filter(grepl(method, pattern='wrong|^non_adapt|^ssplit')) %>%
  mutate(
    K = case_when(
      method %in% c("non_adapt", "ssplit") ~ 3,
      method %in% c('wrong_1_full', 'wrong_1_split') ~ 1,
      method %in% c('wrong_2_full', 'wrong_2_split') ~ 2,
      method %in% c('wrong_4_full', 'wrong_4_split') ~ 4,
      TRUE ~ 0
    ),
    K = factor(paste0('K=',K), levels = c("K=1", 'K=2', 'K=3', 'K=4', 'K=0'))
  ) %>%
  mutate(method_type = case_when(
    method == 'cjbart' ~ 'cjbart',
    grepl(method, pattern='split') ~ 'Split Sample',
    grepl(method, pattern='full|non') ~ 'Full Data'
  )) %>%
  # recode_factor() takes old = new pairs. The 'abs_bias' entry was written
  # the wrong way round ('|Bias|' = 'abs_bias') and could never match; it is
  # corrected here. Only 'rmse' and 'bias' occur in the current data, so the
  # resulting factor is unchanged.
  mutate(
    variable = recode_factor(
      variable,
      'abs_bias' = '|Bias|', 'mae' = 'MAE', 'rmse' = 'RMSE',
      'bias' = 'Marginalized Error', 'coverage' = 'Coverage'
    )
  ) %>%
  mutate(
    sample_size = factor(
      sample_size,
      levels = rev(sort(unique(summary_ame$sample_size)))
    )
  )

# Reference lines at zero in both metric panels
hte_zero_lines <- data.frame(
  x = c(0, 0),
  variable = factor(c('Marginalized Error', 'RMSE'))
)

g_dist_HTE <- ggplot() +
  geom_boxplot(
    aes(x = value, y = sample_size, col = K,
        group = interaction(sample_size, method, K)),
    data = reshape_hte_sims
  ) +
  # `scales` spelled out in full rather than the partially matched `scale`
  facet_grid(method_type ~ variable, scales = 'free_x') +
  geom_vline(aes(xintercept = x), data = hte_zero_lines) +
  labs(col = '') +
  ylab('Sample Size') +
  xlab('') +
  theme_bw() +
  theme(legend.position = 'bottom', strip.text.y.right = element_text(angle = 0))

ggsave(
  g_dist_HTE,
  filename = 'figures/app_rmse_sims.pdf',
  width = 8.5,
  height = 8.5/2
)

# Mean RMSE by method type and K (one column per sample size), expressed
# relative to the K = 3 value within each method type
reshape_hte_sims %>%
  group_by(method_type, method, K, sample_size, variable) %>%
  summarize(value = mean(value)) %>%
  filter(variable == 'RMSE') %>%
  reshape2::dcast(method_type + K ~ sample_size, value.var = 'value') %>%
  group_by(method_type) %>%
  mutate(across(where(is.numeric), ~ ./mean(.[K == 'K=3']))) %>%
  print()

# Figure A9: Estimate bias in averaged CAMCEs vs AMCE

# Attach the display label for sample size; the raw code is kept in
# `orig_samplesize`
hte_AME_sims <- mutate(
  hte_AME_sims,
  orig_samplesize = sample_size,
  sample_size = case_when(
    sample_size == 'small' ~ '1,000 People\n(5 Tasks)',
    sample_size == 'large' ~ '2,000 People\n(10 Tasks)',
    sample_size == 'very_large' ~ '4,000 People\n(10 Tasks)'
  )
)

# Average avg_est and avg_truth within each (factor, level, method, sample
# size) cell, keep the main and wrong-K fits, and label each by K and by
# whether it used the full data or a split sample
sum_hte_AME_sims <- hte_AME_sims %>%
  group_by(factor, level, method, sample_size) %>%
  summarize(across(c(avg_est, avg_truth), mean)) %>%
  filter(grepl(method, pattern='wrong|^non_adapt|^ssplit')) %>%
  mutate(
    K = case_when(
      method %in% c("non_adapt", "ssplit") ~ 3,
      method %in% c('wrong_1_full', 'wrong_1_split') ~ 1,
      method %in% c('wrong_2_full', 'wrong_2_split') ~ 2,
      method %in% c('wrong_4_full', 'wrong_4_split') ~ 4,
      TRUE ~ 0
    ),
    K = factor(paste0('K=',K), levels = c("K=1", 'K=2', 'K=3', 'K=4', 'K=0'))
  ) %>%
  mutate(
    method_type = case_when(
      method == 'cjbart' ~ 'cjbart',
      grepl(method, pattern='split') ~ 'Split Sample',
      grepl(method, pattern='full|non') ~ 'Full Data'
    )
  )

# Boxplots of (average estimate - average truth) by K, with a dashed line at 0
app_g_recons_AME_CAMCE <- ggplot(sum_hte_AME_sims) +
  geom_boxplot(aes(x = avg_est-avg_truth, y = K)) +
  geom_vline(aes(xintercept = 0), linetype = 'dashed') +
  facet_grid(method_type ~ sample_size) +
  xlab('Bias in Average Marginal Effect Estimates') +
  theme_bw() +
  theme(legend.position = 'bottom', strip.text.y.right = element_text(angle = 0))

ggsave(
  app_g_recons_AME_CAMCE,
  filename = 'figures/app_recons_AME.pdf',
  width = 8.5,
  height = 8.5/2
)

# Figure A10: Analysis of Between/Total Variation As Explained by Estimated
# Groups

sim_sos <- readRDS('final_output/final_sos.RDS')

# Stop unless every stored checksum is within 1e-8 of zero
# (total variation is conserved as "Between + Within")
stopifnot(sim_sos %>% filter(abs(checksum_sos) > 1e-8) %>% nrow == 0)

# Row count and column sums of every '*sos*' column per (type, sample size,
# simulation)
sum_sos <- sim_sos %>%
  group_by(type, sample_size, sim) %>%
  summarize(n = n(), across(matches('sos'), sum))

# Rename the two main fits so that every type follows 'wrong_<K>_<full|split>',
# keep the full-data fits, and parse the fit kind and K out of the type name
plot_sos <- sum_sos %>%
  mutate(
    type = case_when(
      type == 'ssplit' ~ 'wrong_3_split',
      type == 'non_adapt' ~ 'wrong_3_full',
      grepl(type, pattern='wrong') ~ type,
      TRUE ~ NA_character_
    )
  ) %>%
  filter(!is.na(type)) %>%
  filter(grepl(type, pattern='full')) %>%
  mutate(
    ssplit = str_extract(type, pattern='(?<=_)(split|full)'),
    K = str_extract(type, pattern='[0-9]')
  ) %>%
  mutate(
    orig_samplesize = sample_size,
    sample_size = case_when(
      sample_size == 'small' ~ '1,000 People\n(5 Tasks)',
      sample_size == 'large' ~ '2,000 People\n(10 Tasks)',
      sample_size == 'very_large' ~ '4,000 People\n(10 Tasks)'
    )
  )

# Confirm Total SOS is conserved across different K
stopifnot(plot_sos %>%
            group_by(ssplit, sim) %>%
            summarize(diff_range = diff(range(total_sos))) %>%
            pull(diff_range) %>% abs %>% max < 1e-7)

# Share of total variability that lies between the estimated groups, by K
g_sos <- ggplot(plot_sos) +
  geom_boxplot(
    aes(x = between_sos/total_sos, y = K, group = interaction(K, ssplit))
  ) +
  facet_grid(. ~ sample_size) +
  labs(col = '') +
  xlab('Between Variability / Total Variability') +
  theme_bw()

ggsave(
  g_sos,
  filename = 'figures/app_sos_sims.pdf',
  width = 8.5,
  height = 8.5/2
)

# Table A2: Performance of data-driven criterion for selecting K

oos_simulation <- readRDS('final_output/final_OOS.RDS')

# Replace the sample-size codes with their display labels
oos_simulation <- oos_simulation %>%
  mutate(
    sample_size = case_when(
      sample_size == 'small' ~ '1,000 People\n(5 Tasks)',
      sample_size == 'large' ~ '2,000 People\n(10 Tasks)',
      sample_size == 'very_large' ~ '4,000 People\n(10 Tasks)'
    )
  )

# For a data frame carrying a `score` column (larger = better), return one row
# per sample size with the share of selected fits falling on each K in 1:4.
# Within each (sim, sample_size) the top-scoring row is kept; ties keep all
# tied rows, as top_n() does. K values that are never selected get a share
# of 0.
share_selected_K <- function(scored_sims) {
  scored_sims %>%
    group_by(sim, sample_size) %>%
    top_n(n = 1, wt = score) %>%
    group_by(sample_size, K) %>%
    tally() %>%
    mutate(K = factor(K, levels = 1:4)) %>%
    group_by(sample_size) %>%
    mutate(pr = n/sum(n)) %>%
    pivot_wider(id_cols = sample_size, names_from = K,
                values_from = pr, names_sort = TRUE,
                values_fill = 0,
                names_expand = TRUE)
}

# Selection by smallest BIC, largest ll, and smallest rmse respectively
oos_BIC <- share_selected_K(mutate(oos_simulation, score = -BIC))

oos_ll <- share_selected_K(mutate(oos_simulation, score = ll))

oos_RMSE <- share_selected_K(mutate(oos_simulation, score = -rmse))

# Only the BIC shares are written to the LaTeX table: one ' & '-separated row
# per sample size, each ended with '\\'
out_oos <- oos_BIC %>%
  apply(MARGIN = 1, FUN=function(i){paste(i, collapse = ' & ')}) %>%
  paste(., '\\\\')

# Close the table after the last row. Indexing by length rather than a fixed
# 3 avoids creating NA rows when fewer than three sample sizes are present.
if (length(out_oos) > 0) {
  last_row <- length(out_oos)
  out_oos[last_row] <- paste(out_oos[last_row], '\\hline\\hline')
}
writeLines(out_oos, 'figures/oos_sims.tex')

print(oos_RMSE)
